--- Code generation for type classes and instances
module frege.compiler.gen.java.InstanceCode where

import frege.Prelude hiding (<+>)

import Lib.PP (text, <+>, </>, <+/>, <>)
import Data.TreeMap as Map(values, lookup, delete, insert, TreeMap)
import Data.List(zip4)

import Compiler.Utilities(findC, findV, forceTau, returnType)
import Compiler.Javatypes(subTypeOf)

import Compiler.enums.Flags(TRACEG)
import Compiler.enums.TokenID(VARID)

import Compiler.types.AbstractJava
import Compiler.types.Symbols
import Compiler.types.Global
import Compiler.types.Types(Ctx, TVar, Tau, SigmaT, TauT, RhoT, Kind, KindT, pSigma)
import Compiler.types.Expression
import Compiler.types.QNames(QName)
import Compiler.types.Packs(pPreludeList)
import Compiler.types.JNames(JName)

import Compiler.common.Errors as E()
import Compiler.common.Types  as CT(substSigma, substRho, substTau, tauRho, 
                                    sigmaKind, tauKind)
import Compiler.common.Binders(allBinders)
import Compiler.common.SymbolTable(changeSym)
import Compiler.common.JavaName(symJavaName, javaName)

import Compiler.classes.Nice (nice, nicer, nicest)

import Compiler.tc.Util(sameCtx)

import Compiler.gen.java.Common
import Compiler.gen.java.Bindings(newBind, Binding)
import Compiler.gen.java.VarCode(varCode, compiling, genExpression, genExpr)
-- import Compiler.gen.java.DataCode(callMethod, asThunkMethod)

{--
    Code for type classes

    - interface @CName@<_a_> where _a_ is the class variable.
    - For each class method, an abstract method is generated.

    The following classes (the "special classes") get a different treatment:

    > ListEmpty, ListMonoid, ListSemigroup, ListView, ListSource

    This is in order to handle 'String', which is not really higher kinded.

    In the corresponding Java type, all applications of the class variable are simply
    replaced by the type variable itself.

    We know that all those functions have the property that the class variable is applied
    to exactly one and the same type variable (i.e. the element type, which is never
    changed by the functions). Hence we have no conflicts here.

    Those functions never have the 'RGeneric' property (as the "original" type variable 
    is replaced). Hence, the function 'uncons' looks like:

    > class ListView c where
    > uncons :: c e -> Maybe (e, c e)
    > fake_uncons :: c -> Maybe (e, c)
    > 
    >
    > interface CListView<C> where
    > public<E> TMaybe<TTuple2<E,C>> uncons(final C arg$1);

    Note that this way, the element type becomes a phantom type, which it is 
    in the case of 'String', and doesn't matter for the other types.

    *IMPLEMENTATION RESTRICTION* Polymorphic functions using one of the special classes
    can still not be instantiated at type 'String'. We use this mechanism only to overload
    the class operations, like (++) and 'length'. 
-}

classCode ∷ Symbol → StG [JDecl]
-- A type class becomes a public Java interface, preceded by a comment naming the class.
-- The interface has one abstract method per member of the class environment.
classCode (sym@SymC{tau = TVar{var,kind}}) = do
    g ← getST
    methodDecls ← mapSt (abstractFun sym) (values sym.env)
    let classVarArg = [TArg var]
        -- generic parameters of the interface
        gvars
            | isSpecialClass sym = [targ g var KType]
            | isArrayClass sym   = [JTVar{var="data", bounds=UNBOUNDED}, targ g var kind]
            | otherwise          = [targ g var kind]
        -- type arguments to pass to the interface of super class @nm@
        superArgs nm
            | isArrayClassName nm = TArg "data" : classVarArg
            | otherwise           = classVarArg
        superIfaces = [ Ref (javaName g nm) (superArgs nm) | nm <- sym.supers ]
        iface = JInterface{attr = attrs [JPublic],
                           name = (symJavaName g sym).base,
                           gvars,
                           implement = superIfaces,
                           defs = concat methodDecls}
    pure [JComment (nice sym g), iface]

--- Anything that is not a type class symbol is reported as fatal compiler error.
classCode sym = do
    g ← getST
    let complaint = text "classCode: argument is " <+> text (nice sym g)
    E.fatal sym.pos complaint

--- Applies 'lowerKindAbstractFun' to every member of every special class
--- (the classes named in 'specialClassNames', looked up in 'pPreludeList').
--- Yields the compiler state as it was before this action; the caller must restore it afterwards.
lowerKindSpecialClasses = do
    before ← getST
    let members = [ (cls, member) | cname ← specialClassNames,
                                    cls ← before.findit (TName pPreludeList cname),
                                    member ← Map.values cls.env ]
    mapM_ (\(cls, member) → lowerKindAbstractFun cls member) members
    pure before

--- Rewrites the type of class member @sym@ of class @symc@ so that every type application
--- whose head is the class variable is replaced by the bare class variable with kind 'KType'.
--- The changed symbol is stored with 'changeSym' and the entry for @sym@ is deleted from @symi8@.
lowerKindAbstractFun ∷ Symbol → Symbol → StG ()
lowerKindAbstractFun symc sym = do
        let classvar = symc.tau.var
            newsym   = sym.{typ <- lowerKind classvar}
        changeSym newsym
        -- force syminfo to regenerate information, if already present
        changeST Global.{gen ← _.{symi8 ← delete sym}}
        return ()
    where
        -- sigma types: only the rho part is rewritten
        lowerKind cv sigma = sigma.{rho <- lowerRhoKind cv}
        -- function types: rewrite argument and result; plain types: rewrite the tau
        lowerRhoKind cv (rho@RhoFun{}) = rho.{sigma ← lowerKind cv, rho ← lowerRhoKind cv}
        lowerRhoKind cv (rho@RhoTau{}) = rho.{tau ← lowerTauKind cv}
        lowerTauKind cv (TSig s)                =  TSig . lowerKind cv $ s
        -- an application headed by the class variable collapses to the variable itself
        -- (its arguments are dropped); other applications are rewritten recursively
        lowerTauKind cv (app@TApp a b) = case app.flat of
            TVar{pos, kind, var}:_ | var == cv  =  TVar{pos, kind=KType, var} 
            other                               =  TApp (lowerTauKind cv a) (lowerTauKind cv b)
        -- everything else stays as it is
        lowerTauKind cv t                       =  t
 



--- Declares the abstract Java method for member @sym@ of class @symc@.
--- The result is a comment showing name and type of the member, followed by the method declaration.
abstractFun ∷ Symbol → Symbol → StG [JDecl]
abstractFun symc (sym@SymV{}) = do
    g  ← getST
    si ← symInfo sym
    let !ownCtx = Ctx {pos = symc.pos, cname = Symbol.name symc, tau = Symbol.tau symc}
        -- the context that mentions the class itself is not passed as argument
        !otherCtxs = filter (not . sameCtx ownCtx) sym.typ.rho.context
        -- the class variable is no generic parameter of the method
        gvars    = targs g sym.typ.{bound ← filter ((!= symc.tau.var) . _.var)}
        ctxArgs  = zipWith (constraintArg g) otherCtxs (getCtxs g)
        valArgs  = argDefs attrFinal si.{argJTs <- map adapt} (getArgs g)
        resultJT = strict (adapt (tauJT g (fst (returnType sym.typ.rho.{context=otherCtxs}))))
    let !decl = JMethod {attr  = attrs [JPublic],
                         gvars,
                         jtype = resultJT,
                         name  = latinF ++ (symJavaName g sym).base,
                         args  = ctxArgs ++ valArgs,
                         body  = JEmpty}
    pure [JComment ((nicer sym g) ++ " :: " ++ nicer sym.typ g), decl]
  where
    -- are we compiling one of the array classes?
    inArrayClass = isArrayClass symc
    dataArg = TArg "data"
    -- In array classes, the native array type is replaced by the type argument "data".
    -- Any other type is made lazy.
    adapt (Lazy t) = adapt t
    adapt Nativ{typ="[]"}
        | inArrayClass = Lazy dataArg
    adapt Func{gargs=[a,Nativ{typ="[]"}]}
        | inArrayClass = Func{gargs = [a, dataArg]}
    adapt Ref{jname,gargs=[a,Nativ{typ="[]"}]}
        | inArrayClass = Ref{jname, gargs = [a, dataArg]}
    adapt t = lazy t

abstractFun symc symx = do
    g ← getST
    let complaint = text "abstractFun: argument is "
                        <+> text (nice symx g)
                        <+> text " for "
                        <+> text (nice symc g)
    E.fatal symx.pos complaint




{--
     Code for instances

    > instance (Pre1 x, Pre2 y) => C (T x y)

    Compiles to a class that implements the interface generated for the class (in 'classCode').

    If there are constraints, the relevant instances must be passed on 
    construction of this one, otherwise, we have a singleton class, i.e.

    > instance Eq Int where ....
    > instance Eq a => Eq (Maybe a) where ...

    becomes

    > class Eq_Int implements CEq<Integer> {
    >        final public Eq_Int it = new Eq_Int();       // singleton
    >        ...
    > }
    > class Eq_Maybe<A> implements CEq<TMaybe<A>> {
    >
    >     public Eq_Maybe(CEq<A> ctx) { ... }
    > }  
-}
-- Generates the Java class for instance @sym@.
-- Yields a comment (instance name and type) followed by the class declaration.
instanceCode (sym@SymI {sid}) = do             -- instance definition
     g <- getST
     -- csym is the class this instance is made for
     csym <- findC sym.clas
 
     -- the class itself plus the classes listed in its supers
     let classes = sym.clas:csym.supers
         special = isSpecialClass csym
         -- the functions we must provide in the instance
         superMethods = [ m.name.base | c <- classes,
                                   SymC{env} <- g.findit c,
                                   m@SymV{}  <- values env ]
         -- links in types that point to instance members of this class and its superclasses
         -- The goal is to have (links to) implementations of all super class methods. 
         methods2 = case instTSym (Symbol.typ sym) g of
              Just (tsym@SymT {pos}) -> [ alias |
                                SymL {name, alias} <- values tsym.env, alias.{tynm?},    -- links
                                alias `notElem` methods1,                 -- avoid duplicates
                                alias.base `elem` superMethods,           -- mentioning one of our methods
                                name.base `notElem` map QName.base methods1,
                                SymI {clas} <- g.findit alias.tynm, -- pointing to an instance
                                SymC {supers} <- g.findit clas,     -- of a class that is in our hierarchy
                                clas `elem` classes || any (`elem` classes) supers]
              _ -> error "unexpected result from instTSym"
         -- names of the members found in the environment of the instance itself
         methods1 = map Symbol.name (values sym.env)
         -- methods of super classes that are implemented in the type itself
         methods3 = case instTSym (Symbol.typ sym) g of
            Just (tsym@SymT {pos}) -> [ sym.name |
                                 sym  <- values tsym.env,
                                 sym.name.base `elem` superMethods,
                                 sym.name.base `notElem` methods] where
                        methods = map QName.base (methods1++methods2)
            _ -> error "unexpected result from instTSym" 
         -- all members an instance function is generated for
         methods  = methods1 ++ methods2 ++ methods3
 
     let vals = values sym.env
 
     -- one declaration and one constructor argument for each context of the instance type
     let constraints = zipWith (constraintDef g)  sym.typ.rho.context (getCtxs g)
         constrargs  = zipWith (constraintArg g)  sym.typ.rho.context (getArgs g)
 
 
     let instName = symJavaName g sym
         -- java type of the instantiated type, contexts ignored
         instjt   = boxed (rhoJT g sym.typ.rho.{context=[]})
         array    = Nativ{typ="[]", gargs=[strict instjt], generic=false}
         rawinst  = rawType instjt
         jtype = Ref instName []
         -- the interface type that the generated class implements
         etype = Ref (symJavaName g csym)  (if special 
                                                then [rawinst] 
                                                else if isArrayClass csym 
                                                    then [array, instjt]
                                                    else [instjt])
         gvars = targs g sym.typ
         -- this  = Constr instName (map (TArg . _.var) gvars)
         attr  = attrs (if special 
                            then [JRawTypes, JPublic, JFinal, JStatic] 
                            else [JPublic, JFinal, JStatic])
         -- the constructor assigns its arguments to the context names, one assignment per constraint
         constructor = JConstr {attr = attrs [JPublic],
                                 jtype = jtype,  
                                 args = constrargs,
                                 body = JBlock (take (length constraints)
                                                 (zipWith JAssign
                                                     (map JAtom (getCtxs g))
                                                     (map JAtom (getArgs g))))}

         -- static method "mk" that returns "it" cast to the wanted type;
         -- only generated when the constructor takes no arguments
         make
            | null constrargs, special = [JMethod{
                            attr = attrs [JUnchecked, JPublic, JFinal, JStatic], 
                            gvars = [JTVar{var="r", bounds=UNBOUNDED}], 
                            jtype = Constr (symJavaName g csym) [TArg "r"], 
                            name  = "mk", 
                            args  = [], 
                            body = JBlock{stmts = stmtssp}}]
            | null constrargs, not (null gvars) =
                    [JMethod{attr = attrs [JUnchecked, JPublic, JFinal, JStatic],
                             gvars,
                             name = "mk",
                             args = [], 
                             jtype = jtype.{gargs},
                             body = JBlock{stmts}}]
            | otherwise = []
            where
                gargs = map (TArg . _.var) gvars 
                stmts = [JReturn (JCast jtype.{gargs} (JAtom "it"))]
                stmtssp = [JReturn (JCast (Constr (symJavaName g csym) [TArg "r"]) (JAtom "it"))]

         -- the member "it" holding the single instance, present only if there are no constraints;
         -- its type has as many type arguments taken from wilds as there are generic variables
         singleton
            | null constrargs = [JMember{attr = attrTop, jtype = jtype.{gargs}, name="it", 
                init = Just (JNew jtype [])}]
            | otherwise = []
            where
                gargs = take (length gvars) wilds

     -- check for implementation restriction
     let k = kArity csym.tau.kind
         kindJT = (("frege.run."++) . show . _.jname . rawType $ Kinded k [])
         -- the first type argument of the implemented interface
         jt = head etype.gargs
         implementationRestriction = not special && isHigherKindedClass csym && not (implementsKinded g k jt)
     when (implementationRestriction) do
        case jt of
            Nativ{typ} | not (subTypeOf g typ kindJT) = E.error sym.pos (
                 text "The type " <+> nicest g sym.typ.rho.{context=[]} 
                 <+> text "cannot be an instance of" <+> text (csym.name.nicer g)
                 <+/> text "because the corresponding java type" <+> text jt.show
                 <+/> text "does not implement the interface " <+> text kindJT <> text ","
                 <+/> text "which would be required to represent it as higher kinded type."
              )
            other → E.error sym.pos (
                text "Implementation restriction: the type" 
                <+> nicest g sym.typ.rho.{context=[]}
                <+> text "cannot be an instance of" <+> text (csym.name.nicer g)
                <+/> text "because attempting to represent"
                <+/> text "it as a higher kinded type" 
                <+/> text "results in the invalid java type" <+> text jt.show <> text ". "
                <+/> text "To be valid," <+> text (show k) 
                    <+> text "wild card type argument(s)"
                    <+> text "should appear from the right,"
                    <+> text "but this is not the case here. "
                <+/> text "Maybe this can be corrected by"
                    <+> text "re-arranging type arguments."
                </> text "Also, if this was a newtype, it'll probably help to change it to data."   
              )
     when (isHigherKindedClass csym) do
        E.logmsg TRACEG sym.pos (text "instanceCode" <+> text (csym.name.nicer g) <+> text jt.show)
     
     -- when the restriction applies, no code is generated for the members
     instFuns  <- mapM (instFun csym sym) (if implementationRestriction then [] else methods)
     instImpls <- mapM (varCode empty)    (if implementationRestriction then [] else vals)
     
     let result = JClass {attr, 
                          name = instName.base,
                          gvars,
                          extend = Nothing,
                          implement = [etype],
                          defs = (constructor {-: callMethod this : asThunkMethod this-} : constraints)
                             ++ singleton
                             ++ make
                             ++ instFuns
                             ++ concat instImpls}     
     pure [JComment (nice sym g ++ " :: " ++ nice sym.typ g), result]

--- If given something else than an instance this is a fatal compiler error
instanceCode sym = do
    g ← getST
    E.fatal sym.pos (
            text "instanceCode: argument is "
            <+> text (nice sym g) 
        ) 

--- Generates the Java method that implements a class member in the class generated for an instance.
--- Arguments are the class @symc@, the instance @symi@ and @mname@, the name of the
--- function that provides the implementation.
--- The generated method overrides the interface method and returns the result of applying
--- the function named @mname@ to the method arguments.
--- It is a fatal error if neither @symc@ nor one of the classes in its supers has a member with that base name.
instFun :: Symbol → Symbol → QName → StG JDecl
instFun symc symi mname = do
        g       ←  getST
        sym     ←  findV mname
        let classnames = symc.name:symc.supers
            special = isSpecialClass symc
            -- members with the same base name in the class and the classes in its supers
            cmems = [ m | cln ← classnames, SymC{env} ← g.findit cln, 
                            m ← env.lookupS mname.base ]
        case cmems of
            [] → E.fatal symi.pos (text "trying to instFun " <+> text (nicer mname g)
                        <+> text " but no class member found.")
            cmem:_ → do 
                -- replace symc with class where method was introduced
                symc ← findC cmem.name.tynm
                E.logmsg TRACEG symi.pos (text "instFun" <+> text (nicer sym g)
                    <+> text "for" <+> text (nicer cmem g))
                -- We need to tweak the types a bit so that java type variables won't conflict.
                -- hypothetical scenario
                -- class C a where
                --     op :: forall a b c. (C a, X b) => a -> b -> c
                -- instance Y b => C (T a b)
                --     op :: forall a b c d. (X c, Y b) => T a b -> c -> d
                -- 1. rename all type variables in C.op that also occur in the instance head
                --     op :: forall x y c. (C x, X y) => x -> y -> c
                -- 2. rename also the class variable
                --     x
                -- 3. remove the class context and the class variable from bounds of (1)
                --     op :: forall y c. (X y) => x -> y -> c
                -- 4. substitute the type for renamed class variable in (3) and add the type vars
                --     op :: forall a b y c. (X y) => T a b -> y -> c
                -- 5. prepend the additional constraints to the ones in (4)
                --     op :: forall a b y c. (Y b, X y) => T a b -> y -> c
                E.logmsg TRACEG symi.pos (
                        text (nicer sym.name g) <+> text " :: " <+> text (nicer sym.typ g)
                        </> text (nicer cmem.name g) <+> text " :: " <+> text (nicer cmem.typ g)
                    )
                -- otvs: type variables of the member type that also occur in the instance type
                -- orep: binder names that do not occur in the member type, used as replacements
                let otvs = filter ((`elem` symi.typ.vars) . Tau.var) cmem.typ.tvars
                    orep = filter (`notElem` (cmem.typ.vars)) (allBinders g)
                    -- renames bound variables according to the substitution, but keeps their kinds
                    substBound :: TreeMap String Tau -> [Tau] -> [Tau]
                    substBound subst xs = map (\tv -> maybe tv _.{kind=tv.kind} (lookup tv.var subst)) xs
                    subst1 = Map.fromList [ (tv.var, tv.{var=s}) | (s,tv) ← zip orep otvs]
                    typ1 = ForAll (substBound subst1 cmem.typ.bound) (substRho subst1 cmem.typ.rho)
                E.logmsg TRACEG symi.pos (
                        text "(1) renamed type :: "
                        <+> text (nicer typ1 g)
                    )

                let cvar = substTau subst1 symc.tau
                    withoutCVar  = filter ((!=) cvar.var . _.var)
                E.logmsg TRACEG symi.pos (
                        text "(2) class var is now " <+> text (nicer cvar g)
                    )
                let classCtx = Ctx {pos=symc.pos, 
                                    cname = symc.name, 
                                    tau = cvar }
                    ctxs = filter (not . sameCtx classCtx) typ1.rho.context
                    rho3 = typ1.rho.{context = ctxs}
                    bound3 = withoutCVar typ1.bound
                    typ3  = ForAll bound3 rho3
                E.logmsg TRACEG symi.pos (
                        text "(3) remove class var and class context from (1) :: "
                        <+> text (nicer typ3 g)
                    )
                let jty3 = rhoJT g typ3.rho
                E.logmsg TRACEG symi.pos (
                        text "(3j) java type of (3) :: " <+> text (show jty3))
                
                -- othertv is referenced only in the special branch of instTau below
                let othertv = head (filter ((!=cvar.var) . _.var) typ3.tvars)
                    instTau0 = (tauRho symi.typ.rho.{context=[]}).tau
                    instTau  = if special then TApp instTau0 othertv else instTau0
                    subst4 = Map.singleton cvar.var instTau
                    methty = ForAll typ3.bound (substRho subst4 typ3.rho)
                    -- for special classes the raw java type is substituted
                    raw    = if special then rawType else id
                    instjt = (boxed . lambdaType . tauJT g) instTau
                    -- jsubst = Map.singleton cvar.var instjt
                    jsubstr = Map.singleton cvar.var (raw instjt)
                    jty4m   = boxed (rhoJT g methty.rho) -- substJT jsubstr (lambdaType jty3)
                    jty4r   = substJT jsubstr  (lambdaType jty3)
                E.logmsg TRACEG symi.pos (
                        text "(4) substitute instantiated type "
                        <+> text (nicer instTau g)
                        <+> text " into (3) :: "
                        <+> text (nicer methty g)
                    )
                E.logmsg TRACEG symi.pos (
                        text "(4j) java type for (4) :: "
                        <+> text (show jty4m)
                    )
                E.logmsg TRACEG symi.pos (
                        text "(4r) substitute instantiated raw type "
                        <+> text (show (raw instjt)) 
                        <+> text " into (3j) :: "
                        <+> text (show jty4r)
                    )

                let jty4 = mergeKinded jty3 jty4m
                E.logmsg TRACEG symi.pos (
                        text "(4m) merge of (3j) and (4j) :: "
                        <+> text (show jty4))

                -- the type under which the expression is compiled: (4) plus bounds and contexts of the instance
                let fakety = methty.{
                                bound ← (symi.typ.bound ++),  -- make sure ctx will be recognized
                                rho ← _.{context ← (symi.typ.rho.context++)}}
                E.logmsg TRACEG symi.pos (
                        text "(5) add extra contexts to (4), final type :: "
                        <+> text (nicer fakety g)
                    )

                -- flattened java function types; the raw variant differs only for special classes
                fjty4  = flatFunc jty4
                fjty4r = if special then flatFunc jty4r else fjty4

                let (_, sigs)    = returnType methty.rho
                    ari          = length sigs
                    -- return type according to the merged type (4m): what remains after dropping
                    -- the first ari elements; made strict if the original return type is primitive
                    retJT
                        | Just _  ← isPrimitive o  = strict t
                        | otherwise = t 
                        where
                            t = case drop ari (dropConstr fjty4) of
                                    []  → Something --??
                                    xs  → foldr (\a\f → Func  [a,f]) (last xs) (init xs)
                            o = (tauJT g . fst . returnType) typ3.rho
                    -- the same, but computed from fjty4r; this is the declared return type of the method
                    retJTr
                        | Just _  ← isPrimitive o  = strict t
                        | otherwise = t 
                        where
                            t = case drop ari (dropConstr fjty4r) of
                                    []  → Something --??
                                    xs  → foldr (\a\f → Func [a,f]) (last xs) (init xs)
                            o = (tauJT g . fst . returnType) typ3.rho

                -- build & compile expression
                uids ← replicateM ari uniqid
                let args = take ari (getArgs g)
                    atoms = map JAtom args
                    -- one local variable per method argument
                    vbls = [ Vbl{pos=symi.pos.change VARID base, name=Local{uid, base}, typ=Just typ} 
                                | (uid,typ,base) ← zip3 uids sigs args]
                    fun = Vbl{pos=symi.pos.change VARID sym.name.base, name=sym.name, 
                                typ = Just methty.{bound=[]}}
                    -- the implementing function applied to all the local variables
                    ex = fold app fun vbls
                        where app a b = App{fun=a, arg=b, typ=fmap (reducedSigma g) a.typ}
                    -- manipulate function arg types and make args Kinded if so in original type 
                    gargs0 = case jty4r of               -- declared type of the arguments
                                Func{gargs} → fjty4r
                                other       → []
                    gargs1 = case jty4 of                -- the real type of the arguments
                                Func{gargs} → fjty4
                                other       → []
                    -- the leading nctxs entries belong to the contexts, the rest to the value arguments
                    nctxs  = length methty.rho.context
                    cgargs = take nctxs gargs0
                    vgargs = drop nctxs gargs0
                    rgargs = drop nctxs gargs1
                    binds = Map.fromList [(uid, (newBind g sig atom).{jtype=lazy jt})
                                | (uid, sig, atom, jt) ← zip4 uids sigs atoms vgargs ]
                    -- for special classes, the arguments are cast via Something to the type from rgargs
                    rawbinds
                        | special = Map.fromList [(uid, (newBind g sig ex).{jtype=lazy jt})
                                | (uid, sig, atom, jt) ← zip4 uids sigs atoms rgargs,
                                                        ex = JCast (lazy jt) (JCast Something atom) ]
                        | otherwise = binds
                result ← compiling sym.{typ=fakety} (genExpr true retJT ex rawbinds)
                -- for special classes, cast the result if declared and computed return type differ
                let rex
                        | special, retJTr != retJT = JCast retJTr (JCast Something result.jex)
                        | otherwise                = result.jex

                    -- tells whether a cast to Something occurs in the parts of the expression inspected here
                    unchecked :: JExpr -> Bool
                    unchecked ex = case ex of
                        JAtom{name} →  false
                        JNew{jt, args} →  any unchecked args
                        JNewClass{jt, args, decls} →  false
                        JLambda{fargs, code} →  false
                        JNewArray{jt, jex} →  unchecked jex
                        JInvoke{jex, args} →  unchecked jex || any unchecked args 
                        JStMem{jt, name, targs} →  false
                        JExMem{jex, name, targs} →  unchecked jex
                        JCast Something _   → true
                        JCast{jt, jex} →  unchecked jex
                        JUnop{op, jex} →  unchecked jex
                        JBin{j1, op, j2} →  unchecked j1 || unchecked j2
                        JQC{j1, j2, j3} →  unchecked j1 || unchecked j2 || unchecked j3
                        JArrayGet{j1, j2} →  false

                -- finally make the function
                -- JUnchecked is added for special classes, when the result expression contains
                -- a cast to Something, or when needsUnchecked says so
                pure JMethod{attr = if special || unchecked rex 
                                    || needsUnchecked fst cmem.name (Map.lookupDefault Something cvar.var jsubstr) 
                                        then attrs [JUnchecked, JPublic, JFinal, JOverride]
                                        else attrs [JPublic, JFinal, JOverride], 
                            gvars = targs g methty, 
                            jtype = retJTr, 
                            name = latinF ++ (symJavaName g sym).base, 
                            -- context arguments take the names that follow those used for the instance contexts
                            args = [(attrFinal, pSigma, ctx, name) 
                                       | (ctx,name) ← zip cgargs 
                                                          (drop (length symi.typ.rho.context) 
                                                                (getCtxs g)) ]
                                    ++ [(attrFinal, pSigma, lazy jt, name)
                                        | (jt,name) ← zip vgargs args], 
                            body = JBlock{stmts=[JReturn rex]}}
